Disease classification on PlantVillage

In this chapter, we will design a CNN to perform the plant disease classification task on the PlantVillage dataset. This includes several steps, which are outlined below:

  1. Explore and preprocess the PlantVillage dataset
  2. Design an isotropic CNN architecture
  3. Train the CNN on the PlantVillage dataset
  4. Analyze the accuracy of the CNN model using a hierarchical confusion matrix

The PlantVillage dataset

The PlantVillage dataset is a collection of 54,305 images of 14 different plant species, spanning 38 classes: 12 healthy and 26 diseased.

The dataset was created by the Penn State College of Agricultural Sciences and the International Institute of Tropical Agriculture as a resource for research and development of computer vision-based plant disease detection systems. The images in the dataset were collected from various sources, including research institutions and citizen scientists, and represent a wide variety of plant species and disease types.

The plants include fruits such as apple, blueberry, cherry, grape, orange, peach, raspberry, squash, and strawberry; crops such as corn and soybean; and vegetables such as bell pepper, potato, and tomato. Each plant is either healthy or affected by a disease such as scab, rot, or rust.

import pandas as pd
df = pd.read_csv('data/cls_count.csv')
df[['Plant', 'Disease', 'Count']]
    Plant         Disease                                Count
0   Apple         Apple_scab                               630
1   Apple         Black_rot                                621
2   Apple         Cedar_apple_rust                         275
3   Apple         healthy                                 1645
4   Blueberry     healthy                                 1502
5   Cherry        Powdery_mildew                          1052
6   Cherry        healthy                                  854
7   Corn          Cercospora_leaf_spot Gray_leaf_spot      513
8   Corn          Common_rust                             1192
9   Corn          Northern_Leaf_Blight                     985
10  Corn          healthy                                 1162
11  Grape         Black_rot                               1180
12  Grape         Esca_(Black_Measles)                    1383
13  Grape         Leaf_blight_(Isariopsis_Leaf_Spot)      1076
14  Grape         healthy                                  423
15  Orange        Haunglongbing_(Citrus_greening)         5507
16  Peach         Bacterial_spot                          2297
17  Peach         healthy                                  360
18  Pepper,_bell  Bacterial_spot                           997
19  Pepper,_bell  healthy                                 1478
20  Potato        Early_blight                            1000
21  Potato        Late_blight                             1000
22  Potato        healthy                                  152
23  Raspberry     healthy                                  371
24  Soybean       healthy                                 5090
25  Squash        Powdery_mildew                          1835
26  Strawberry    Leaf_scorch                             1109
27  Strawberry    healthy                                  456
28  Tomato        Bacterial_spot                          2127
29  Tomato        Early_blight                            1000
30  Tomato        Late_blight                             1909
31  Tomato        Leaf_Mold                                952
32  Tomato        Septoria_leaf_spot                      1771
33  Tomato        Spider_mites Two-spotted_spider_mite    1676
34  Tomato        Target_Spot                             1404
35  Tomato        Tomato_Yellow_Leaf_Curl_Virus           5357
36  Tomato        Tomato_mosaic_virus                      373
37  Tomato        healthy                                 1591

The number of images differs considerably from class to class. Such a skewed distribution of image counts is called imbalanced, and an imbalanced dataset is more difficult to train on than a balanced one.
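A common first step for handling imbalance is to weight each class inversely to its frequency in the training loss, so that rare classes contribute as much as common ones. The sketch below uses a few of the counts from the table above; the inverse-frequency rule shown is one standard heuristic, not the only option:

```python
import pandas as pd

# A few (plant, disease) counts copied from the table above.
counts = pd.Series({
    "Apple Apple_scab": 630,
    "Orange Haunglongbing_(Citrus_greening)": 5507,
    "Potato healthy": 152,
})

# Inverse-frequency weights: rarer classes receive larger loss weights.
weights = counts.sum() / (len(counts) * counts)
print(weights.round(3))
```

These weights can then be passed to a weighted loss function during training.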

xticks = range(38)
ax = df.plot.bar(
    x='Disease', y='Count', 
    title='Imbalanced distribution of the counts of images',
    xlabel='Classes', xticks=xticks, 
    figsize=(10,5))
legend = ax.legend(loc=2)

Next, let us show 38 images, one for each category.

import os
root_dir = "data/plantvillage/"
samples = []
classes = sorted(os.listdir(root_dir))
for cls in classes:
    cls_path = os.path.join(root_dir, cls)
    if not os.path.isdir(cls_path):
        continue
    img_names = sorted(os.listdir(cls_path))
    if img_names:
        # Keep one representative image per class.
        samples.append(os.path.join(cls_path, img_names[0]))
from PIL import Image
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
fig = plt.figure(figsize=(12., 20.))
grid = ImageGrid(fig, 111,  # similar to subplot(111)
                 nrows_ncols=(6, 7),  # creates 6x7 grid of axes
                 axes_pad=0.1,  # pad between axes in inch.
                 )
for ax in grid:
    ax.axis("off")

for ax, img_path in zip(grid, samples):
    img = Image.open(img_path)
    ax.axis("off")
    ax.imshow(img)

A lightweight isotropic CNN architecture

CNNs can be divided into two types: isotropic and pyramidal. Isotropic CNNs keep the same feature-map resolution and channel width across all layers of the network, while pyramidal CNNs progressively shrink the spatial resolution and widen the channels from layer to layer. The difference between them is illustrated by the figure below.

isotropic_vs_pyramidical = Image.open("data/isotropic_vs_pyramidical.PNG")
fig = plt.figure(figsize=(10, 5))
plt.imshow(isotropic_vs_pyramidical)
plt.axis("off")
plt.show()

Isotropic CNNs emerged partly inspired by the state-of-the-art attention-based transformer architectures in computer vision, which are themselves isotropic. Recent research finds that, compared to pyramidal architectures, isotropic architectures can match or even exceed state-of-the-art performance with considerably lighter layers.
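To make the idea concrete, here is a minimal isotropic CNN sketch in PyTorch. It is not FoldNet; the patch size, width, and depth below are illustrative choices. The point to notice is that every block after the stem keeps the same channel width and spatial resolution:

```python
import torch
import torch.nn as nn

class IsotropicCNN(nn.Module):
    """A minimal isotropic CNN: a patchify stem, then identical blocks."""

    def __init__(self, num_classes=38, dim=64, depth=4, patch=8):
        super().__init__()
        # The stem reduces resolution once; after it, nothing changes shape.
        self.stem = nn.Conv2d(3, dim, kernel_size=patch, stride=patch)
        self.blocks = nn.Sequential(*[
            nn.Sequential(
                nn.Conv2d(dim, dim, kernel_size=3, padding=1),
                nn.BatchNorm2d(dim),
                nn.ReLU(),
            )
            for _ in range(depth)
        ])
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(dim, num_classes)
        )

    def forward(self, x):
        return self.head(self.blocks(self.stem(x)))

model = IsotropicCNN()
out = model(torch.randn(2, 3, 256, 256))
print(out.shape)  # torch.Size([2, 38])
```

A pyramidal network would instead use strided convolutions or pooling between blocks, halving resolution and doubling width at each stage.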

We propose a lightweight isotropic CNN, FoldNet, which achieves 99.84% accuracy on the disease classification task on the PlantVillage dataset.

from IPython.display import Image, display
display(Image('data/foldnet_arch.png', width="80%"))

Hierarchical Confusion Matrix of PlantVillage

A confusion matrix is a visualization tool in machine learning for evaluating the performance of a classification model. It is a tabular layout that compares predicted class labels against actual class labels over all data instances. The rows of the matrix represent the actual classes, while the columns represent the predicted classes. By analyzing the confusion matrix, we can determine how well the model distinguishes between different classes, as well as which classes are most often confused with one another. Popular performance metrics, such as accuracy, precision, recall, and F1 score, can be derived from the confusion matrix.
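As a sketch of these derivations, the following computes accuracy and per-class precision and recall directly from a toy confusion matrix (the counts are made up for illustration):

```python
import numpy as np

# Toy 3-class confusion matrix: rows = actual class, columns = predicted class.
cm = np.array([
    [50,  2,  0],
    [ 3, 45,  4],
    [ 0,  1, 60],
])

accuracy = np.trace(cm) / cm.sum()         # correct predictions / all predictions
precision = np.diag(cm) / cm.sum(axis=0)   # per class, over each predicted column
recall = np.diag(cm) / cm.sum(axis=1)      # per class, over each actual row
f1 = 2 * precision * recall / (precision + recall)
print(f"accuracy={accuracy:.3f}")
```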

The PlantVillage dataset has a tree-like hierarchical structure with three levels. The root node is the overall category, plant; the first level contains the 14 plant species; the second level contains the healthy or diseased status of each plant. We therefore use a hierarchical confusion matrix to capture this structure.
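One useful operation on such a hierarchy is aggregating the leaf-level (disease) confusion matrix up to the species level: a misclassification stays on the diagonal of the aggregated matrix exactly when the model confused diseases within the correct species. A sketch with a hypothetical two-species, four-class matrix:

```python
import numpy as np

# Hypothetical leaf classes: 0, 1 belong to species 0; 2, 3 to species 1.
species_of = np.array([0, 0, 1, 1])

# Hypothetical leaf-level confusion matrix (rows = actual, cols = predicted).
cm = np.array([
    [30,  5,  1,  0],
    [ 2, 40,  0,  0],
    [ 0,  0, 25,  3],
    [ 0,  1,  2, 20],
])

# Sum each block of leaf classes into one species-level cell.
n_species = species_of.max() + 1
species_cm = np.zeros((n_species, n_species), dtype=int)
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        species_cm[species_of[i], species_of[j]] += cm[i, j]
print(species_cm)
```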

The following is an interactive widget to visualize the hierarchical confusion matrix of the FoldNet model when evaluating on the 10,861 testing images of the PlantVillage dataset.

The FoldNet model achieves 99.84% accuracy, with only 17 images classified incorrectly. After quantitatively analyzing these 17 images, we find three points worth noting:

First, compared to incorrect classification within the same species, incorrect classification across species is very rare. Only 5 images are misidentified as a different plant species, while the other 12 images are identified correctly as to their species, even though incorrectly as to their disease status. This reflects the robustness of the FoldNet model, which can correctly predict the species of an image even when its prediction of the image's disease status is wrong.

Second, the 12 images that are incorrectly classified within the same species belong to only two species, corn and tomato, rather than being uniformly distributed across all 14 species: 4 of the 12 images belong to corn, and the other 8 to tomato. This reflects the complexity of the corn and tomato images.

Third, among the 17 incorrectly classified images, several have incorrect ground-truth labels or were captured under extreme conditions. For example, the first "Cherry healthy" image actually shows only field background, and two "Tomato Late_blight" images have a very small foreground against a very large background.
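Splitting errors into within-species and across-species cases, as in the first point above, can be done directly from the label strings. The sketch below assumes labels of the form "Species___Disease" (the common PlantVillage folder-name convention); the example labels and predictions are hypothetical, not the model's actual output:

```python
# Hypothetical true and predicted labels in "Species___Disease" form.
y_true = ["Corn___Common_rust", "Tomato___healthy", "Tomato___Early_blight"]
y_pred = ["Corn___Northern_Leaf_Blight", "Tomato___healthy", "Potato___Early_blight"]

within, across = 0, 0
for t, p in zip(y_true, y_pred):
    if t == p:
        continue  # correct prediction, not an error
    if t.split("___")[0] == p.split("___")[0]:
        within += 1   # correct species, wrong disease status
    else:
        across += 1   # wrong species entirely
print(within, across)  # 1 1
```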

display(Image('data/falsely_predicted_across_species.png', width="100%"))
display(Image('data/falsely_predicted_inner_species.png', width="100%"))